"""TODO: Add docstring."""

import json
import os

import numpy as np
import pyarrow as pa
from dora import Node

# Handle on the dora node; its incoming events are iterated further down.
node = Node()

# Resize ratio taken from the IMAGE_RESIZE_RATIO environment variable ("1.0" when unset).
IMAGE_RESIZE_RATIO = float(os.environ.get("IMAGE_RESIZE_RATIO", "1.0"))


def extract_bboxes(json_text: str) -> "tuple[np.ndarray | None, np.ndarray | None]":
    """Extract bounding boxes and labels from JSON text wrapped in markdown fences.

    Parameters
    ----------
    json_text : str
        JSON text, optionally surrounded by markdown fence lines (lines that
        start with three backticks, e.g. the ```json marker). It is expected
        to decode to a list of objects that each carry a ``"bbox_2d"`` and a
        ``"label"`` entry.

    Returns
    -------
    bboxes : np.ndarray or None
        Array built from every ``"bbox_2d"`` entry, in input order.
    labels : np.ndarray or None
        Array built from every ``"label"`` entry, aligned with ``bboxes``.

    Both values are ``None`` when the text is not valid JSON or does not have
    the expected structure.

    """
    # Drop the fence lines so only the JSON payload remains.
    payload = "\n".join(
        line
        for line in json_text.strip().splitlines()
        if not line.strip().startswith("```")
    )

    try:
        data = json.loads(payload)
        bboxes = np.array([item["bbox_2d"] for item in data])
        labels = np.array([item["label"] for item in data])
    except (ValueError, KeyError, TypeError):
        # ValueError: invalid JSON (json.JSONDecodeError) or ragged boxes that
        # NumPy refuses to stack; KeyError: an item lacks a required field;
        # TypeError: the payload is not a list of objects.
        # Malformed input is an expected case, so report "nothing found".
        return None, None
    return bboxes, labels


# Latest value received on the "prompt" input. Kept as None until one arrives
# so that a "text" event showing up first is skipped instead of failing on an
# undefined name.
prompt = None

# Boxes are multiplied by the inverse of IMAGE_RESIZE_RATIO (presumably to undo
# an upstream image resize -- confirm against the producer). The factor stays a
# float: truncating it to int turns e.g. a 0.4 ratio into x2 instead of x2.5,
# and any ratio above 1 into x0.
scale = 1 / IMAGE_RESIZE_RATIO

for event in node:
    if event["type"] != "INPUT":
        continue

    if event["id"] == "prompt":
        prompt = event["value"][0].as_py()
        continue

    if event["id"] != "text":
        continue

    text = event["value"][0].as_py()
    image_id = event["metadata"]["image_id"]

    if prompt is None:
        # No prompt yet, so there is nothing to match labels against.
        continue

    bboxes, labels = extract_bboxes(text)
    if bboxes is None or len(bboxes) == 0:
        continue
    if not np.issubdtype(bboxes.dtype, np.number):
        # Ragged or non-numeric coordinates cannot be scaled; drop them.
        continue

    scaled = bboxes * scale
    if np.issubdtype(bboxes.dtype, np.integer):
        # Keep integer coordinates integer, rounding to the nearest value.
        scaled = np.rint(scaled).astype(bboxes.dtype)
    bboxes = scaled

    # For each distinct label mentioned in the prompt, remember where it is
    # mentioned and the index of its first detection.
    idx = []
    order = []
    for label in np.unique(labels):
        position = prompt.find(label)
        if position == -1:
            continue
        order.append(position)
        idx.append(np.where(labels == label)[0][0])

    if not idx:
        continue

    # Order the selected detections by where their label occurs in the prompt.
    idx = np.array(idx)[np.argsort(order)].ravel()
    bboxes = bboxes[idx]

    # Two selected labels resolving to the same box is treated as invalid.
    if len(np.unique(bboxes, axis=0)) != len(bboxes):
        print("Duplicated box")
        continue

    node.send_output(
        "bbox",
        pa.array([{"bbox": bboxes.ravel(), "labels": labels[idx]}]),
        metadata={"encoding": "xyxy", "image_id": image_id},
    )
